import numpy as np
import scipy
from qutip import tensor, identity, destroy, liouvillian, operator_to_vector, spre, expect, commutator, Qobj
from collections import namedtuple

# Model parameters, consumed positionally by ss_system_base / solve_full and by name in X_coefs,
# so the field order below is part of the interface.
#   n_ex, n_phon          -- exciton / phonon Fock truncation (local dimension n + 1)
#   M                     -- harmonics n = M, M-1, ..., -M are kept (2*M + 1 blocks)
#   omega_phon            -- phonon frequency (omega_phon * b^dag b); also scales the Deltaq coupling
#   Deltaq                -- exciton-phonon coupling, term omega_phon*Deltaq*(b + b^dag)*X^dag X
#   g0rtN                 -- coupling of the cavity amplitude alpha to X (name suggests g0*sqrt(N))
#   Omega_p, Omega_m      -- drive amplitudes in the pump terms oscillating at +/- Deltaomega
#   Delta_ex              -- exciton detuning (term -Delta_ex * X^dag X)
#   Delta_phon            -- offset so that Deltaomega = omega_phon + Delta_phon
#   Delta_c               -- cavity detuning
#   Gamma_ex, Gamma_phon  -- damping rates of the collapse operators sqrt(Gamma)*X, sqrt(Gamma)*b
#   kappa                 -- cavity loss rate
SystemParams = namedtuple(
    'SystemParams',
    'n_ex n_phon M omega_phon Deltaq g0rtN Omega_p Omega_m Delta_ex Delta_phon Delta_c '
    'Gamma_ex Gamma_phon kappa')

# Numerical options for solve_full (also unpacked positionally there).
#   root_find          -- run the self-consistent search for the cavity harmonics
#   eigensolve         -- also compute the monodromy matrix and its Floquet exponents
#   alpha_guess        -- starting cavity harmonics (length 2*M + 1) or None
#   use_jacobian       -- pass the analytic Jacobian to the root finder
#   monodromy_npoints  -- number of time steps per period for the monodromy matrix
SolverParams = namedtuple(
    'SolverParams',
    'root_find eigensolve alpha_guess use_jacobian monodromy_npoints')


def ss_system_base(system_params):
    """Assemble the cavity-independent part of the harmonic-space Liouvillian.

    The result acts on the stacked, column-vectorised density-matrix harmonics
    rho_n for n = M, M-1, ..., -M (block 0 holds n = M). Each diagonal block
    is the static Liouvillian (exciton, phonon and exciton-phonon Hamiltonian
    plus exciton/phonon damping) shifted by -i*n*Deltaomega. The pump terms
    couple neighbouring harmonics through the first block super- and
    sub-diagonals.

    Returns a scipy CSC matrix of shape (n_blocks*op_dim, n_blocks*op_dim), where
    n_blocks = 2*M + 1 and op_dim = ((n_ex + 1)*(n_phon + 1))**2.
    """
    (n_ex, n_phon, M, omega_phon, Deltaq, _g0rtN, Omega_p, Omega_m, Delta_ex, Delta_phon, _Delta_c,
     Gamma_ex, Gamma_phon, _kappa) = system_params
    Deltaomega = omega_phon + Delta_phon
    n_blocks = 2*M + 1
    op_dim = ((n_ex + 1) * (n_phon + 1))**2

    # Lowering operators on the joint exciton (x) phonon space.
    X = tensor(destroy(n_ex + 1), identity(n_phon + 1))
    b = tensor(identity(n_ex + 1), destroy(n_phon + 1))

    H_static = omega_phon*b.dag()*b - Delta_ex*X.dag()*X + omega_phon*Deltaq*(b+b.dag())*X.dag()*X
    H_up = Omega_p/2.*X + Omega_m/2.*X.dag()      # feeds harmonic n-1 into n (block diagonal k=+1)
    H_down = Omega_m/2.*X + Omega_p/2.*X.dag()    # feeds harmonic n+1 into n (block diagonal k=-1)
    L_static = liouvillian(H_static, [np.sqrt(Gamma_ex)*X, np.sqrt(Gamma_phon)*b]).data

    # Same static Liouvillian on every harmonic ...
    result = scipy.sparse.kron(np.identity(n_blocks), L_static, format='csc')
    # ... with each block shifted by -i*n*Deltaomega.
    shifts = [-1.j * n * Deltaomega for n in range(M, -M-1, -1)]
    result = result + scipy.sparse.kron(np.diag(shifts), scipy.sparse.identity(op_dim), format='csc')

    for offset, H_pump in ((1, H_up), (-1, H_down)):
        result = result + scipy.sparse.kron(np.eye(n_blocks, k=offset), liouvillian(H_pump).data,
                                            format='csc')
    return result


def X_coefs(system_params):
    """Build the sparse operators through which the cavity harmonics enter the system matrix.

    Both returned lists have length 2*M + 1 and follow the harmonic order
    n = M, M-1, ..., -M. Entry j of the first list places liouvillian(X^dag)
    on block diagonal k = n_j. Entry j of the second list places
    liouvillian(X) on block diagonal k = -n_j. solve_full scales them by
    g0rtN*alpha_{n_j} and g0rtN*conj(alpha_{n_j}) respectively.
    """
    n_ex, n_phon, M = system_params.n_ex, system_params.n_phon, system_params.M
    X = tensor(destroy(n_ex + 1), identity(n_phon + 1))
    L_X, L_Xdag = liouvillian(X).data, liouvillian(X.dag()).data
    n_blocks = 2*M + 1
    harmonics = range(M, -M-1, -1)
    # alpha_n * exp(i n w t) shifts a harmonic up by n; conj(alpha_n) carries exp(-i n w t) -> shift by -n.
    coef_a = [scipy.sparse.kron(np.eye(n_blocks, k=n), L_Xdag, format='csc') for n in harmonics]
    coef_ac = [scipy.sparse.kron(np.eye(n_blocks, k=-n), L_X, format='csc') for n in harmonics]
    return coef_a, coef_ac


def solve_full(system_params, solver_params):
    """Solve for the periodic steady state of the model and, optionally, its Floquet exponents.

    The density matrix and the cavity amplitude are expanded in the 2*M + 1
    harmonics n = M, M-1, ..., -M of Deltaomega = omega_phon + Delta_phon.
    For given cavity harmonics alpha_n, the density-matrix harmonics rho_n
    follow from one sparse LU solve. alpha_n is then fixed self-consistently
    by driving d(alpha_n)/dt to zero with scipy.optimize.root(method='lm').

    Parameters
    ----------
    system_params : SystemParams
    solver_params : SolverParams
        root_find -- if true, search for the self-consistent alpha; otherwise
            alpha_guess is taken as the answer.
        eigensolve -- if true, also build the monodromy matrix of the
            linearised dynamics over one period and return its Floquet exponents.
        alpha_guess -- starting harmonics (length 2*M + 1), or None for zeros.
        use_jacobian -- supply the analytic Jacobian to the root finder.
        monodromy_npoints -- number of piecewise-constant steps per period.

    Returns
    -------
    (sol, rho_ss, alpha_ss, eigval, eigvec)
        sol -- OptimizeResult from the root finder, or None if root_find is false.
        rho_ss -- list of 2*M + 1 Qobj density-matrix harmonics, computed at
            the returned alpha_ss.
        alpha_ss -- complex ndarray with the 2*M + 1 cavity harmonics.
        eigval, eigvec -- exponents log(mu)/T of the monodromy eigenvalues mu,
            sorted by decreasing real part, and the matching eigenvectors
            (ordered as [alpha, alpha*, vec(rho)]). Both are None when
            eigensolve is false.

    Raises
    ------
    ValueError
        If alpha_guess does not have 2*M + 1 entries.
    """
    (n_ex, n_phon, M, omega_phon, Deltaq, g0rtN, Omega_p, Omega_m, Delta_ex, Delta_phon, Delta_c, Gamma_ex, Gamma_phon,
     kappa) = system_params
    (root_find, eigensolve, alpha_guess, use_jacobian, monodromy_npoints) = solver_params
    Deltaomega = omega_phon + Delta_phon
    h_dim = (n_ex+1) * (n_phon+1)
    op_dim = h_dim ** 2
    n_blocks = 2*M + 1
    ns = np.arange(M, -M-1, -1)
    X = tensor(destroy(n_ex + 1), identity(n_phon + 1))
    b = tensor(identity(n_ex + 1), destroy(n_phon + 1))
    H_mat_0 = omega_phon*b.dag()*b - Delta_ex*X.dag()*X
    H_mat_exphon = omega_phon * Deltaq * (b+b.dag()) * X.dag()*X
    H_mat_plus = Omega_p/2.*X + Omega_m/2.*X.dag()
    H_mat_minus = Omega_m/2.*X + Omega_p/2.*X.dag()

    system_base = ss_system_base(system_params)
    X_coef_a, X_coef_ac = X_coefs(system_params)

    # Trace constraint, inspired by QuTiP's LU steady-state solver.
    # weight * Tr(rho_0) (diagonal entries of the n = 0 block) is added onto row
    # M*op_dim, and that row's RHS is set to the same weight. A solution with
    # L rho = 0 therefore has Tr(rho_0) = 1.
    trace_constraint_weight = np.abs(system_base.data.max()) * 1.
    trace_constraint = scipy.sparse.csc_matrix((np.ones(h_dim),
        (M * op_dim * np.ones(h_dim, dtype='int32'),
         [M*op_dim + n*(h_dim+1) for n in range(h_dim)])),
        shape=(n_blocks * op_dim, n_blocks * op_dim))
    system_RHS = np.zeros(system_base.shape[0])
    system_RHS[M*op_dim] = trace_constraint_weight

    I_opbra = operator_to_vector(tensor(identity(n_ex + 1), identity(n_phon + 1))).dag()
    TrX_1 = (I_opbra * spre(X)).data                           # row vector: vec(rho) -> Tr(X rho)
    TrX = scipy.sparse.kron(np.identity(n_blocks), TrX_1)     # Tr(X .) applied to every harmonic
    Delta_c_diag = np.diag(-Delta_c + ns*Deltaomega)
    kappa_diag = kappa * np.identity(n_blocks)

    def solve_rho(alpha_n):
        """Solve the constrained linear steady-state problem for fixed cavity harmonics alpha_n.

        Returns (LU factorisation of the constrained matrix, stacked solution
        vector, list of rho_n as Qobj).
        """
        system = system_base.copy()
        for a, Xca, Xcastar in zip(alpha_n, X_coef_a, X_coef_ac):
            system += g0rtN*a*Xca + g0rtN*np.conj(a)*Xcastar
        lu_solver = scipy.sparse.linalg.splu(system + trace_constraint_weight*trace_constraint)
        rhovec_sol = lu_solver.solve(system_RHS)
        rhos = [Qobj(rhovec_sol[(n*op_dim):((n+1)*op_dim)].reshape((h_dim, h_dim), order='F'),
                     dims=[[n_ex + 1, n_phon + 1], [n_ex + 1, n_phon + 1]])
                for n in range(n_blocks)]
        return lu_solver, rhovec_sol, rhos

    def dt_alpha(alpha_split):
        """Residual d(alpha_n)/dt as [Re..., Im...]; also the Jacobian when use_jacobian is set."""
        alpha_n = [a_re + 1.j*a_im for a_re, a_im in zip(alpha_split[:n_blocks], alpha_split[n_blocks:])]
        lu_solver, rhovec_sol, rhos = solve_rho(alpha_n)
        dalpha = [-1.j * ((-Delta_c - 1.j*kappa + n*Deltaomega)*alpha + g0rtN*expect(X, rho))
                  for n, alpha, rho in zip(ns, alpha_n, rhos)]
        dalpha_split = list(np.real(dalpha)) + list(np.imag(dalpha))
        if not use_jacobian:
            return dalpha_split
        # Column m is g0rtN**2 * Tr(X A^{-1} dA rho) with dA = X_coef_a[m] +/- X_coef_ac[m]
        # (A = constrained system matrix). Up to sign and a factor of i, this is the
        # response of Tr(X rho_n) to Re/Im alpha_m, since d rho = -A^{-1} dA rho.
        w_plus = np.array([g0rtN**2 * (TrX @ lu_solver.solve((Xca + Xcastar)@rhovec_sol))
                           for Xca, Xcastar in zip(X_coef_a, X_coef_ac)]).T
        w_minus = np.array([g0rtN**2 * (TrX @ lu_solver.solve((Xca - Xcastar)@rhovec_sol))
                            for Xca, Xcastar in zip(X_coef_a, X_coef_ac)]).T
        jac = np.block([[-kappa_diag - np.imag(w_plus), Delta_c_diag - np.real(w_minus)],
                        [-Delta_c_diag + np.real(w_plus), -kappa_diag - np.imag(w_minus)]])
        return dalpha_split, jac

    if alpha_guess is None:
        alpha_guess = np.zeros(n_blocks, dtype='complex128')
    alpha_guess = np.asarray(alpha_guess, dtype='complex128')
    if alpha_guess.shape != (n_blocks,):
        raise ValueError(f'alpha_guess must have 2*M + 1 = {n_blocks} entries, got shape {alpha_guess.shape}')
    alpha_split_guess = list(np.real(alpha_guess)) + list(np.imag(alpha_guess))
    if root_find:
        sol = scipy.optimize.root(dt_alpha, alpha_split_guess, method='lm', jac=use_jacobian)
        alpha_ss = sol.x[:n_blocks] + 1.j*sol.x[n_blocks:]
    else:
        sol = None
        alpha_ss = alpha_guess
    # The root finder's last call of dt_alpha is not guaranteed to be at sol.x
    # (e.g. a rejected trial step), and no call happens at all when root_find is
    # false. So rho is recomputed here for exactly the alpha that is returned.
    _, _, rho_ss = solve_rho(alpha_ss)

    if not eigensolve:
        return sol, rho_ss, alpha_ss, None, None

    # Monodromy matrix of the linearised (alpha, alpha*, vec(rho)) dynamics over one period.
    T = 2 * np.pi / Deltaomega
    dt = T / monodromy_npoints
    # Step midpoints, latest first. Right-multiplying the step propagators thus gives
    # the time-ordered product expm(J(t_N) dt) ... expm(J(t_1) dt).
    tlist = np.linspace(T - dt/2, dt/2, num=monodromy_npoints)
    monodromy = np.identity(2 + op_dim, dtype='complex128')
    J_aa = scipy.sparse.diags([1.j*Delta_c - kappa])
    J_acac = scipy.sparse.diags([-1.j*Delta_c - kappa])
    J_arho = -1.j * g0rtN * TrX_1
    J_acrho = 1.j * g0rtN * (I_opbra * spre(X.dag())).data
    c_ops = [np.sqrt(Gamma_ex)*X, np.sqrt(Gamma_phon)*b]
    for t in tlist:
        # Reconstruct the steady-state trajectory at time t from its harmonics.
        phase_factors = np.exp(1j*ns*Deltaomega*t)
        alpha_ss_t = phase_factors @ alpha_ss
        rho_ss_t = sum([rho * phase_factor for rho, phase_factor in zip(rho_ss, phase_factors)])
        J_rhoa = operator_to_vector(-1.j * g0rtN * commutator(X.dag(), rho_ss_t)).data
        J_rhoac = operator_to_vector(-1.j * g0rtN * commutator(X, rho_ss_t)).data
        H_mat_excav_t = g0rtN * (alpha_ss_t*X.dag() + np.conj(alpha_ss_t)*X)
        H_mat_expump_t = np.exp(1j*Deltaomega*t)*H_mat_plus + np.exp(-1j*Deltaomega*t)*H_mat_minus
        J_rhorho = liouvillian(H_mat_0 + H_mat_exphon + H_mat_excav_t + H_mat_expump_t, c_ops).data
        J = scipy.sparse.bmat(
            [[J_aa, None, J_arho],
             [None, J_acac, J_acrho],
             [J_rhoa, J_rhoac, J_rhorho]])
        monodromy = monodromy @ scipy.linalg.expm((J*dt).toarray())
    eigval, eigvec = np.linalg.eig(monodromy)
    eigval = np.log(eigval) / T
    re_order = np.argsort(-np.real(eigval))
    return sol, rho_ss, alpha_ss, eigval[re_order], eigvec[:, re_order]